[SYCL] Extract args directly from kernel if we can #18387

Open, wants to merge 11 commits into base: sycl

@Pennycook (Contributor) commented May 9, 2025

In some cases, all values that need to be passed as kernel arguments are stored within the kernel function object, and their offsets can be calculated using the integration header or equivalent built-ins. In such cases, we can therefore set kernel arguments directly without staging via MArgs.

This first attempt is limited to the simplest cases where all kernel arguments are either standard layout types or pointers. It may be possible to extend this approach to cover other cases, but only if some classes are redesigned.

The implementation currently stores some information (e.g., the number of kernel arguments) inside of the handler, because there is no way to access the kernel type within handler::finalize().
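A minimal sketch of the idea described above, assuming a simplified descriptor format. The names `ParamDesc`, `AxpyKernel`, `FakeBackend`, `setArgsDirect`, and `readArg` are hypothetical and are not part of the DPC++ runtime; the real integration header and UR calls look different.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-in for an integration-header entry: where a captured
// value lives inside the kernel function object, and how large it is.
struct ParamDesc {
  std::size_t Offset;
  std::size_t Size;
};

// Hypothetical kernel function object with two standard-layout captures.
struct AxpyKernel {
  float Alpha;
  float *Data;
};

// Descriptor table a compiler could emit for AxpyKernel (assumed layout).
constexpr ParamDesc AxpyParams[] = {
    {offsetof(AxpyKernel, Alpha), sizeof(float)},
    {offsetof(AxpyKernel, Data), sizeof(float *)},
};

// Stand-in for the backend's "set argument" call: records the bytes it sees.
struct FakeBackend {
  std::vector<std::vector<unsigned char>> Args;

  void setArgValue(std::size_t Index, std::size_t Size, const void *ArgPtr) {
    assert(ArgPtr != nullptr);
    if (Args.size() <= Index)
      Args.resize(Index + 1);
    const auto *Bytes = static_cast<const unsigned char *>(ArgPtr);
    Args[Index].assign(Bytes, Bytes + Size);
  }
};

// Direct path: read each argument straight out of the function object,
// without building an intermediate argument vector first.
template <std::size_t N>
void setArgsDirect(FakeBackend &Backend, const void *KernelFunc,
                   const ParamDesc (&Params)[N]) {
  const auto *Base = static_cast<const unsigned char *>(KernelFunc);
  for (std::size_t I = 0; I < N; ++I)
    Backend.setArgValue(I, Params[I].Size, Base + Params[I].Offset);
}

// Test helper: reinterpret the recorded bytes of one argument as a T.
template <typename T> T readArg(const FakeBackend &Backend, std::size_t Index) {
  T Value;
  std::memcpy(&Value, Backend.Args[Index].data(), sizeof(T));
  return Value;
}
```

Calling `setArgsDirect(Backend, &K, AxpyParams)` on an `AxpyKernel K` hands the backend the bytes of `K.Alpha` and `K.Data` in order, which is the part of the work that staging via MArgs would otherwise do first.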


Some notes for reviewers:

  • This depends on the new hasSpecialCaptures functionality introduced in [SYCL] Add hasSpecialCaptures() constexpr function #18386, which returns false for kernels that only capture standard layout classes and pointers (and true otherwise).

  • There are some seemingly unrelated changes in kernel_desc.hpp and to some of the unit tests. These changes were necessary because hasSpecialCaptures requires getParamDesc to be constexpr. I think this wasn't picked up during [SYCL] Add hasSpecialCaptures() constexpr function #18386 because hasSpecialCaptures wasn't previously being run for every kernel.

  • I'm not really satisfied with adding so many member variables, but it was the best way I could think of to limit the scope of the changes required. Long-term, it would be better to extract everything (including the complicated cases) directly from the lambda, to design an abstraction that unifies the MArgs and MKernelFuncPtr paths, or to find a way to access the required values without storing them in the handler (e.g., using something like [SYCL] Optimize kernel name based cache lookup #18081).

@Pennycook (Contributor Author)

This is failing with:

# .---command stdout------------
# | There are new symbols in the new library. It is a non-breaking change. Refer to sycl/doc/developer/ABIPolicyGuide.md for further instructions.
# | The following symbols are new to the object file:
# | 
# | ?prepareForDirectArgumentCopy@handler@_V1@sycl@@AEAAXPEBXHP6A?AUkernel_param_desc_t@detail@23@H@Z@Z
# `-----------------------------

I've never added anything to handler before, and the intent here wasn't to make this function part of the public API. What did I do wrong, and what do I have to change?

@sergey-semenov (Contributor)

> I've never added anything to handler before, and the intent here wasn't to make this function part of the public API. What did I do wrong, and what do I have to change?

The added function is private, so it's not part of the API, but since it's defined in the library source code it becomes part of the ABI. I wouldn't say that's a problem (e.g. it allows applications to benefit from some changes to that function without recompilation), but you can move the definition to the header if you'd like to keep it out of the library ABI boundary.

@@ -198,6 +198,11 @@ class handler_impl {

// Allocation ptr to be freed asynchronously.
void *MFreePtr = nullptr;

// A pointer to a blob of direct kernel arguments, alternative to MArgs.
Contributor

I'd prefer we use an std::variant to represent that.
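One way to read this suggestion, as a rough sketch only: hold either the staged argument list or the direct-argument information in a single `std::variant`, so the two states cannot both be active. `ArgDesc` here is a simplified stand-in for `detail::ArgDesc`, and `DirectArgs`, `KernelArgs`, `Overloaded`, and `countArgs` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <variant>
#include <vector>

// Hypothetical, simplified stand-in for detail::ArgDesc.
struct ArgDesc {
  const void *Ptr;
  std::size_t Size;
};

// Hypothetical stand-in for "a pointer to a blob of direct kernel arguments".
struct DirectArgs {
  const void *KernelFunc;
  int NumParams;
};

// Exactly one representation is active at a time.
using KernelArgs = std::variant<std::vector<ArgDesc>, DirectArgs>;

// Helper for building an overload set to pass to std::visit.
template <class... Ts> struct Overloaded : Ts... {
  using Ts::operator()...;
};
template <class... Ts> Overloaded(Ts...) -> Overloaded<Ts...>;

// Returns how many arguments would be set, whichever path is active.
inline std::size_t countArgs(const KernelArgs &Args) {
  return std::visit(
      Overloaded{
          [](const std::vector<ArgDesc> &Staged) { return Staged.size(); },
          [](const DirectArgs &Direct) {
            return static_cast<std::size_t>(Direct.NumParams);
          }},
      Args);
}
```

With this shape, code that consumes the arguments must handle both alternatives explicitly, instead of checking a separate pointer for null next to the existing vector.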

Comment on lines 817 to 819
prepareForDirectArgumentCopy((const void *)&KernelFunc,
detail::getKernelNumParams<KernelName>(),
&detail::getKernelParamDesc<KernelName>);
@aelovikov-intel (Contributor), May 12, 2025

I think this is wrong in its current version: line 768 might have left KernelFunc in a moved-from state. How exactly is this beneficial? We've already created a copy in MHostKernel, so what extra overhead are you trying to eliminate?

Edit: is #18413 doing exactly that?
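The moved-from concern can be illustrated with a small sketch. It assumes a hypothetical `HostKernel` holder and a `Kernel` whose `std::unique_ptr` member makes the moved-from state observable; neither type is the runtime's real class.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical kernel function object; the unique_ptr makes the moved-from
// state observable (it is guaranteed to be null after a move).
struct Kernel {
  std::unique_ptr<int> Payload;
};

// Hypothetical stand-in for the runtime's stored copy of the kernel.
template <typename KernelT> class HostKernel {
  KernelT MKernel;

public:
  explicit HostKernel(KernelT &&K) : MKernel(std::move(K)) {}
  const void *getPtr() const { return &MKernel; }
};

// Reads the payload through a type-erased pointer, as a deferred
// argument-extraction step would.
inline bool hasPayload(const void *KernelPtr) {
  return static_cast<const Kernel *>(KernelPtr)->Payload != nullptr;
}
```

After `HostKernel<Kernel> Stored(std::move(K))`, a pointer to `K` no longer sees the payload, while `Stored.getPtr()` does. That is why any pointer kept for later extraction should refer to the stored copy rather than to the caller's object.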

Contributor Author

Exactly, this is why I started looking at #18413.

I'm going to work on merging these two PRs today. The end result should be that:

  • We defer extracting arguments from the kernel until we reach handler::finalize()
  • We skip extracting arguments from the kernel if it's a simple case

What I'm trying to eliminate is the overhead of:

  • Creating an MArgs vector
  • Checking that MArgs is sorted
  • Copying the arguments into MArgs before they're passed to UR
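The plan above can be sketched roughly as follows. This is a toy model, not the handler's real code: `MiniHandler`, `StagedArg`, `ParamDesc`, `ScaleKernel`, and `scaleKernelParamDesc` are hypothetical, the backend call is replaced by a counter, and the general path is simplified to the same offset reads as the fast path.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified descriptor for one kernel parameter.
struct ParamDesc {
  std::size_t Offset;
  std::size_t Size;
};

// Hypothetical staged argument, standing in for an MArgs entry.
struct StagedArg {
  const void *Ptr;
  std::size_t Size;
};

class MiniHandler {
public:
  // Submission time: remember where the arguments live, extract nothing yet.
  void setKernelInfo(const void *KernelFuncPtr, int NumParams,
                     ParamDesc (*GetParamDesc)(int), bool HasSpecialCaptures) {
    MKernelFuncPtr = KernelFuncPtr;
    MNumParams = NumParams;
    MGetParamDesc = GetParamDesc;
    MHasSpecialCaptures = HasSpecialCaptures;
  }

  // Returns how many arguments were staged in MArgs before being passed on.
  std::size_t finalize() {
    assert(MKernelFuncPtr != nullptr && MGetParamDesc != nullptr);
    const auto *Base = static_cast<const unsigned char *>(MKernelFuncPtr);
    if (!MHasSpecialCaptures) {
      // Simple case: pass each argument straight from the function object.
      for (int I = 0; I < MNumParams; ++I) {
        const ParamDesc Desc = MGetParamDesc(I);
        setBackendArg(Base + Desc.Offset, Desc.Size);
      }
      return 0;
    }
    // General case (simplified): stage the arguments, then pass them on.
    for (int I = 0; I < MNumParams; ++I) {
      const ParamDesc Desc = MGetParamDesc(I);
      MArgs.push_back(StagedArg{Base + Desc.Offset, Desc.Size});
    }
    for (const StagedArg &Arg : MArgs)
      setBackendArg(Arg.Ptr, Arg.Size);
    return MArgs.size();
  }

  std::size_t backendCalls() const { return MBackendCalls; }

private:
  // Stand-in for the real "set kernel argument" call.
  void setBackendArg(const void *, std::size_t) { ++MBackendCalls; }

  const void *MKernelFuncPtr = nullptr;
  int MNumParams = 0;
  ParamDesc (*MGetParamDesc)(int) = nullptr;
  bool MHasSpecialCaptures = true;
  std::vector<StagedArg> MArgs;
  std::size_t MBackendCalls = 0;
};

// Hypothetical kernel and its descriptor function.
struct ScaleKernel {
  int Factor;
  int *Out;
};

inline ParamDesc scaleKernelParamDesc(int I) {
  return I == 0 ? ParamDesc{offsetof(ScaleKernel, Factor), sizeof(int)}
                : ParamDesc{offsetof(ScaleKernel, Out), sizeof(int *)};
}
```

Both paths make the same number of backend calls; the fast path just never allocates or fills the staging vector.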

};

applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
if (DirectArgs) {
Contributor

Can we delay args extraction (by passing a callback from the user's code maybe) until here, so that non-special arguments would naturally be copied directly by the nature of the modified implementation?

Pennycook added 3 commits May 13, 2025 11:03

Rather than extracting arguments from the lambda when the kernel is enqueued,
store a pointer to the lambda alongside relevant information from the
integration header or compiler builtins.

Storing this information will allow us to defer the extraction of arguments
until we reach handler::finalize(), at which point it may be possible to
set the kernel arguments directly without populating MArgs.

Signed-off-by: John Pennycook <[email protected]>

hasSpecialCaptures requires getParamDesc to be constexpr. Several tests
were previously incompatible with this requirement, but it was only
discovered when trying to call hasSpecialCaptures for each kernel.

Signed-off-by: John Pennycook <[email protected]>

We are currently only able to skip argument extraction in the case
where a lambda has no special captures. We can only detect this
while we have the kernel type name, and must carry it through until
we call handler::finalize().

Signed-off-by: John Pennycook <[email protected]>

Pennycook added 2 commits May 13, 2025 12:17

This shouldn't be necessary, but in my experiments, the compiler does not
optimize the function call away unless it is used in a constexpr if.

Signed-off-by: John Pennycook <[email protected]>

Signed-off-by: John Pennycook <[email protected]>

Pennycook added 2 commits May 13, 2025 14:23

Assuming the alternative could lead to skipping MArgs inconsistently.

Signed-off-by: John Pennycook <[email protected]>

Kernels without special captures might still have an elimination mask.

Signed-off-by: John Pennycook <[email protected]>
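To illustrate why the elimination mask still matters on the direct path, here is a small sketch. It assumes the mask is a per-parameter boolean where `true` means the parameter was optimized away and an empty mask means nothing was eliminated; `survivingParams` is a hypothetical helper, not a runtime function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Returns the original indices of the parameters that still need to be set.
// The position of an entry in the result is the index the backend would see
// (the role NextTrueIndex plays in the snippets quoted in this thread).
inline std::vector<std::size_t>
survivingParams(std::size_t NumParams,
                const std::vector<bool> &EliminatedArgMask) {
  assert(EliminatedArgMask.empty() || EliminatedArgMask.size() == NumParams);
  std::vector<std::size_t> Surviving;
  for (std::size_t I = 0; I < NumParams; ++I) {
    if (!EliminatedArgMask.empty() && EliminatedArgMask[I])
      continue; // assumed meaning: the device compiler removed this parameter
    Surviving.push_back(I);
  }
  return Surviving;
}
```

Even a kernel whose captures are all standard layout or pointers can have entries removed this way, so the direct path has to skip them and renumber the remaining arguments just as the MArgs path does.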
@Pennycook Pennycook marked this pull request as ready for review May 13, 2025 15:35
@Pennycook Pennycook requested a review from a team as a code owner May 13, 2025 15:35
@Pennycook Pennycook requested a review from maarquitos14 May 13, 2025 15:35
Comment on lines 801 to 811
if constexpr (detail::hasSpecialCaptures<KernelName>()) {
setKernelInfo((void *)MHostKernel->getPtr(),
detail::getKernelNumParams<KernelName>(),
&(detail::getKernelParamDesc<KernelName>),
detail::isKernelESIMD<KernelName>(), true);
} else {
setKernelInfo((void *)MHostKernel->getPtr(),
detail::getKernelNumParams<KernelName>(),
&(detail::getKernelParamDesc<KernelName>),
detail::isKernelESIMD<KernelName>(), false);
}
Contributor

Suggested change
- if constexpr (detail::hasSpecialCaptures<KernelName>()) {
-   setKernelInfo((void *)MHostKernel->getPtr(),
-                 detail::getKernelNumParams<KernelName>(),
-                 &(detail::getKernelParamDesc<KernelName>),
-                 detail::isKernelESIMD<KernelName>(), true);
- } else {
-   setKernelInfo((void *)MHostKernel->getPtr(),
-                 detail::getKernelNumParams<KernelName>(),
-                 &(detail::getKernelParamDesc<KernelName>),
-                 detail::isKernelESIMD<KernelName>(), false);
- }
+ setKernelInfo((void *)MHostKernel->getPtr(),
+               detail::getKernelNumParams<KernelName>(),
+               &(detail::getKernelParamDesc<KernelName>),
+               detail::isKernelESIMD<KernelName>(),
+               detail::hasSpecialCaptures<KernelName>());

or if that won't initialize hasSpecialCaptures, maybe the following will?

Suggested change
- if constexpr (detail::hasSpecialCaptures<KernelName>()) {
-   setKernelInfo((void *)MHostKernel->getPtr(),
-                 detail::getKernelNumParams<KernelName>(),
-                 &(detail::getKernelParamDesc<KernelName>),
-                 detail::isKernelESIMD<KernelName>(), true);
- } else {
-   setKernelInfo((void *)MHostKernel->getPtr(),
-                 detail::getKernelNumParams<KernelName>(),
-                 &(detail::getKernelParamDesc<KernelName>),
-                 detail::isKernelESIMD<KernelName>(), false);
- }
+ constexpr bool HasSpecialCapt = detail::hasSpecialCaptures<KernelName>();
+ setKernelInfo((void *)MHostKernel->getPtr(),
+               detail::getKernelNumParams<KernelName>(),
+               &(detail::getKernelParamDesc<KernelName>),
+               detail::isKernelESIMD<KernelName>(),
+               HasSpecialCapt);

Contributor Author

The first suggestion doesn't work -- that's what I had before, and it didn't get optimized away. I'll try the second form and let you know what happens.

Contributor Author

Okay, the second option seems to work. I've made that change in 54b4621. Thanks!
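The language rule behind the second form can be shown with a standalone sketch. Initializing a `constexpr` variable requires a constant expression, so the compiler must evaluate the call itself; passing the call directly as a function argument carries no such requirement. The names `anyParamIsSpecial`, `LastFlag`, and `submit` are hypothetical, and the per-parameter check is a placeholder.

```cpp
#include <cassert>

// Hypothetical constexpr predicate that loops over a kernel's parameters.
// If it were evaluated at run time, its cost would grow with NumParams.
template <int NumParams> constexpr bool anyParamIsSpecial() {
  bool Found = false;
  for (int I = 0; I < NumParams; ++I)
    Found |= (I % 4 == 3); // placeholder for a real per-parameter check
  return Found;
}

// Stand-in for the non-template function that receives the flag.
inline bool LastFlag = false;
inline void setKernelInfo(bool HasSpecial) { LastFlag = HasSpecial; }

template <int NumParams> void submit() {
  // A constexpr variable must be initialized by a constant expression, so
  // the loop above is evaluated by the compiler rather than on every launch.
  constexpr bool HasSpecial = anyParamIsSpecial<NumParams>();
  setKernelInfo(HasSpecial);
}

// The predicate is usable in constant expressions.
static_assert(anyParamIsSpecial<4>());
static_assert(!anyParamIsSpecial<2>());
```

Whether a given compiler also folds the direct-argument form is an optimization decision; the `constexpr` local removes that dependence.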

case kernel_param_kind_t::kind_std_layout: {
int Size = ParamDesc.info;
Adapter->call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
Size, nullptr, ArgPtr);
Contributor

SetArgBasedOnType does the following:

    if (Arg.MPtr) {
      Adapter->call<UrApiKind::urKernelSetArgValue>(
          Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
    } else {
      Adapter->call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
                                                    Arg.MSize, nullptr);
    }

Is there a reason we don't do that here? Is the else-case falling under "special captures"?

Contributor Author

If I'm honest, I don't really understand how the old MArgs code works. You can see that each argument in MArgs is represented by a pointer to an argument, and this else branch only triggers when that pointer is null.

On the fast path, I'm extracting a standard layout argument directly from the function object. Since it's a standard layout object and not a pointer to one, it can never be null, so I removed the branch.

Contributor

Maybe this is some special case for some local memory accessors. I wonder if these changes can handle that correctly. Let's hope testing is good enough!

Contributor Author

It could be some overloading of the meaning of the kernel parameter kind. The vector contains the arguments after decomposition, which means that there may indeed be some special fields that need handling.

The array I'm working with contains a description of the original arguments, and so anything that's captured as a "standard layout" class really needs to be one -- a local accessor is identified in the array as an accessor, so it will be counted as a special capture.
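As a rough illustration of that classification, the sketch below assumes the predicate returns true as soon as any original parameter is something other than a standard-layout object or a pointer, which is how the thread uses it. `ParamKind`, `ParamDesc`, and the two descriptor arrays are hypothetical simplifications of `kernel_param_kind_t` and `kernel_param_desc_t`.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical subset of parameter kinds, modeled on names in this thread.
enum class ParamKind { StdLayout, Pointer, Accessor };

// Hypothetical, simplified descriptor of one original (pre-decomposition)
// kernel parameter.
struct ParamDesc {
  ParamKind Kind;
};

// True if any parameter is neither a standard-layout object nor a pointer.
template <std::size_t N>
constexpr bool hasSpecialCaptures(const ParamDesc (&Params)[N]) {
  for (std::size_t I = 0; I < N; ++I)
    if (Params[I].Kind != ParamKind::StdLayout &&
        Params[I].Kind != ParamKind::Pointer)
      return true;
  return false;
}

// A kernel capturing a scalar and a USM pointer: eligible for the fast path.
constexpr ParamDesc SimpleKernelParams[] = {{ParamKind::StdLayout},
                                            {ParamKind::Pointer}};

// A kernel capturing a scalar and a local accessor: the accessor is listed
// as an accessor, so the kernel is not eligible.
constexpr ParamDesc LocalAccKernelParams[] = {{ParamKind::StdLayout},
                                              {ParamKind::Accessor}};

static_assert(!hasSpecialCaptures(SimpleKernelParams));
static_assert(hasSpecialCaptures(LocalAccKernelParams));
```

Because the check runs over the original parameter list, a standard-layout entry always refers to real bytes inside the function object, so the null-pointer branch from the MArgs path has nothing to handle here.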

default:
throw std::runtime_error("Direct kernel argument copy failed.");
}
};
Contributor

Could it make sense to add a SetArgBasedOnType overload or variant that the old one can call if the arguments don't fall under the special types? Avoids a little code replication.

Contributor Author

Not with the current design. This is what I was alluding to when I said it would be a good idea to try and unify the MArgs design with what I've done here, though.

The main problem is that we have two different ways to represent what an argument is. MArgs is a vector of detail::ArgDesc objects, but what I'm reading from the integration header is an array of detail::kernel_param_desc_t objects. Submission uses either the vector or the array (and the vector doesn't exist on the fast path) so we can't currently mix and match.

Contributor

Ah, I see what you mean. Yeah, I am not sure it's worth trying to repack the arguments.

Contributor

I'm not sure if I'm replying to the correct thread, but why can't you sink MArgs creation all the way down to applyFuncOnFilteredArgs? This seems to be the place where we process arguments one-by-one anyway, so we'd be able to unify processing without "repacking". Am I missing something obvious?

Contributor Author

I think that could be possible if we had more special-casing in the submission paths, but I didn't want to embark on a redesign that big.

I might have misunderstood the current code, but I convinced myself that the MArgs vector as it's currently used could be populated either by extracting arguments from a lambda, or from the user calling set_arg/set_args to provide the arguments for a sycl::kernel. If we pushed MArgs creation all the way down, I'm not sure how to handle the set_arg/set_args case.

I wrote this primarily to demonstrate that extracting things directly from the kernel function object was 1) possible; and 2) faster. I'd love to make that the default path for everything, but I don't understand the runtime well enough!

Contributor

Can you describe how exactly you're testing performance of this? Either PR's description or maybe a comment near the fast path code. That would help whoever will try to generalize your approach for all submission paths to ensure your "fast path" doesn't regress.

@aelovikov-intel (Contributor) left a comment

Argh, I didn't actually "submit" yesterday.

&(detail::getKernelParamDesc<KernelName>),
detail::getKernelNumParams<KernelName>(),
detail::isKernelESIMD<KernelName>());
// Force hasSpecialCaptures to be evaluated at compile-time.
Contributor

Just for the sake of performance, right?

Contributor Author

Yes, exactly. We can loop over all the kernel arguments at compile-time, so we should. The cost of evaluating this function (at run-time) scales with the number of arguments, and we have to pay that cost for every launch.

Labels: performance (Performance related issues)

4 participants